

import math
import copy
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch.nn.init import normal_

from typing import Sequence
from einops import rearrange
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.runner.base_module import BaseModule
from mmcv.cnn.bricks.transformer import (
    BaseTransformerLayer,
    TransformerLayerSequence,
    build_transformer_layer_sequence
)
from mmcv.cnn import (
    build_activation_layer,
    build_conv_layer,
    build_norm_layer,
    xavier_init
)
from mmcv.cnn.bricks.registry import (
    ATTENTION,
    TRANSFORMER_LAYER,
    TRANSFORMER_LAYER_SEQUENCE
)
from mmcv.utils import (
    ConfigDict,
    build_from_cfg,
    deprecated_api_warning,
    to_2tuple
)
from mmdet.models.utils.builder import TRANSFORMER
from .ImageCrossAttention import ImageCrossAttention, ImageMSDeformableAttention3D
from .PointCloudCrossAttention import PointCloudCrossAttention, PointCloudMSDeformableAttention3D

@TRANSFORMER.register_module()
class AFSATransformer(BaseModule):
    """Decoder-only transformer that fuses image and point-cloud features.

    Multi-level image features and point-cloud features are flattened into
    token sequences, tagged with learned camera / level / queue embeddings,
    and handed to the configured decoder together with the queries.

    Args:
        num_cams (int): Number of rows of the learned camera embedding
            table. Defaults to 6.
        num_feature_levels (int): Number of rows of the learned level
            embedding table. Defaults to 2.
        queue_length (int): Number of rows of the learned queue embedding
            table. Defaults to 3.
        encoder (`mmcv.ConfigDict` | Dict): Config of the encoder layer
            sequence. When None, no encoder is built. The encoder is not
            used by ``forward``. Defaults to None.
        decoder (`mmcv.ConfigDict` | Dict): Config of the decoder layer
            sequence. It is always built, so a config is required in
            practice even though the default is None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Stored as ``self.cross``; not read inside this class.
            Defaults to False.
    """

    def __init__(self,
                 num_cams=6,
                 num_feature_levels=2,
                 queue_length=3,
                 encoder=None,
                 decoder=None,
                 init_cfg=None,
                 cross=False):
        super(AFSATransformer, self).__init__(init_cfg=init_cfg)
        self.num_cams = num_cams
        self.num_feature_levels = num_feature_levels
        self.queue_length = queue_length
        if encoder is not None:
            self.encoder = build_transformer_layer_sequence(encoder)
        else:
            self.encoder = None
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.cross = cross
        # Allocated uninitialized here; filled with values in init_weights().
        self.cams_embeds = nn.Parameter(torch.Tensor(self.num_cams, self.embed_dims))
        self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.queue_embeds = nn.Parameter(torch.Tensor(self.queue_length, self.embed_dims))

    def init_weights(self):
        """Initialize parameters.

        Every parameter with more than one dimension gets Xavier-uniform
        init, the attention modules then run their own initializer, and the
        three embedding tables are finally re-drawn from a normal
        distribution.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        attention_types = (
            ImageCrossAttention,
            PointCloudCrossAttention,
            PointCloudMSDeformableAttention3D,
            ImageMSDeformableAttention3D,
        )
        for m in self.modules():
            if isinstance(m, attention_types):
                # Prefer `init_weight` and fall back to `init_weights`.
                # Looking the method up explicitly (instead of wrapping the
                # call in try/except AttributeError) keeps an AttributeError
                # raised *inside* the initializer from being swallowed.
                init_fn = getattr(m, 'init_weight', None)
                if init_fn is None:
                    init_fn = m.init_weights
                init_fn()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        normal_(self.queue_embeds)
        # Same bookkeeping as the sibling transformers in this file.
        self._is_init = True

    @staticmethod
    def _level_start_index(spatial_shapes):
        """Return the offset of each level's first token in a flattened sequence."""
        return torch.cat((
            spatial_shapes.new_zeros((1,)),
            spatial_shapes.prod(1).cumsum(0)[:-1],
        ))

    def forward(self, x, x_img, query_pos_embeds, query, reference_points, img_metas, attn_masks=None):
        """Forward function for `Transformer`.

        Args:
            x (Sequence[Tensor]): Point-cloud feature maps, one per level,
                each of shape [bs, queue_len, c, h, w].
            x_img (Sequence[Tensor]): Image feature maps, one per level,
                each of shape [bs, queue_len, num_cam, c, h, w].
            query_pos_embeds (Tensor): The positional encoding for query,
                forwarded to the decoder as ``query_pos``.
            query (Tensor): Input query. Its device is also used for the
                spatial-shape tensors.
            reference_points (Tensor): The reference points of offset.
            img_metas: Forwarded unchanged to the decoder.
            attn_masks (Tensor | None): Forwarded to the decoder as the
                first entry of ``[attn_masks, None, None]``.
                Defaults to None.

        Returns:
            The decoder output, unchanged. Presumably of shape
            [num_layers, bs, num_query, dim] -- this depends on the
            configured decoder.
        """
        # image
        img_feat_flatten = []
        img_spatial_shapes = []
        for lvl, feat in enumerate(x_img):
            # Unpacking also asserts that the feature map is 6-D.
            _, _, _, _, h, w = feat.shape
            # [bs, queue, cam, c, h, w] -> [bs, queue, cam, h*w, c]
            feat = feat.flatten(4).permute(0, 1, 2, 4, 3)
            feat = feat + self.cams_embeds[None, None, :, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None, None, lvl:lvl + 1, :].to(feat.dtype)
            feat = feat + self.queue_embeds[None, :, None, None, :].to(feat.dtype)
            img_spatial_shapes.append((h, w))
            img_feat_flatten.append(feat)
        # Concatenate the levels along the token axis.
        img_feat_flatten = torch.cat(img_feat_flatten, dim=3)
        img_spatial_shapes = torch.as_tensor(
            img_spatial_shapes, dtype=torch.long, device=query.device)
        img_level_start_index = self._level_start_index(img_spatial_shapes)

        # point cloud
        pts_feat_flatten = []
        pts_spatial_shapes = []
        for feat in x:
            # Unpacking also asserts that the feature map is 5-D.
            _, _, _, h, w = feat.shape
            # [bs, queue, c, h, w] -> [bs, queue, h*w, c]
            feat = feat.flatten(3).permute(0, 1, 3, 2)
            feat = feat + self.queue_embeds[None, :, None, :].to(feat.dtype)
            pts_spatial_shapes.append((h, w))
            pts_feat_flatten.append(feat)
        # The token axis is dim 2 here (dim 3 is the channel axis), so the
        # levels must be joined on dim 2 to agree with pts_spatial_shapes and
        # pts_level_start_index. Identical to the old dim=3 for one level.
        pts_feat_flatten = torch.cat(pts_feat_flatten, dim=2)
        pts_spatial_shapes = torch.as_tensor(
            pts_spatial_shapes, dtype=torch.long, device=query.device)
        # Derived from the point-cloud shapes (previously, by mistake, from
        # the image shapes).
        pts_level_start_index = self._level_start_index(pts_spatial_shapes)

        out_dec = self.decoder(
            query=query,
            query_pos=query_pos_embeds,
            reference_points=reference_points,
            img_feat_flatten=img_feat_flatten,
            pts_feat_flatten=pts_feat_flatten,
            img_spatial_shapes=img_spatial_shapes,
            img_level_start_index=img_level_start_index,
            pts_spatial_shapes=pts_spatial_shapes,
            pts_level_start_index=pts_level_start_index,
            img_metas=img_metas,
            key_padding_mask=None,
            attn_masks=[attn_masks, None, None],
            reg_branch=None,
            )

        return out_dec


@TRANSFORMER.register_module()
class AFSALidarTransformer(BaseModule):
    """Transformer wrapper that decodes queries against one feature map.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of the encoder layer
            sequence. When None, no encoder is built. Defaults to None.
        decoder ((`mmcv.ConfigDict` | Dict)): Config of the decoder layer
            sequence; it is always built.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Stored as ``self.cross``; not read inside this class.
            Defaults to False.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None, cross=False):
        super().__init__(init_cfg=init_cfg)
        self.encoder = (
            None if encoder is None
            else build_transformer_layer_sequence(encoder)
        )
        self.decoder = build_transformer_layer_sequence(decoder)
        self.cross = cross
        self.embed_dims = self.decoder.embed_dims

    def init_weights(self):
        """Xavier-uniform init for every sub-module whose weight has >1 dim."""
        for module in self.modules():
            if not hasattr(module, 'weight'):
                continue
            if module.weight.dim() > 1:
                xavier_init(module, distribution='uniform')
        self._is_init = True

    def forward(self, x, mask, query_embed, pos_embed, attn_masks=None, reg_branch=None):
        """Decode zero-initialized queries against the flattened feature map.

        Args:
            x (Tensor): Feature map of shape [bs, c, h, w].
            mask (Tensor): Padding mask; reshaped to [bs, -1] before use.
            query_embed (Tensor): Query positional embedding; its first two
                dims are swapped before decoding (presumably
                [bs, num_query, dim] -> [num_query, bs, dim]).
            pos_embed (Tensor): 2-D key positional embedding (presumably
                [h*w, c]); repeated across the batch as [h*w, bs, c].
            attn_masks (Tensor | None): Passed to the decoder as the first
                entry of ``[attn_masks, None]``. Defaults to None.
            reg_branch: Forwarded unchanged to the decoder.

        Returns:
            tuple: ``(out_dec, memory)`` where ``out_dec`` is the decoder
            output with dims 1 and 2 swapped, and ``memory`` is ``x``
            rearranged to [h*w, bs, c].
        """
        batch_size, _, _, _ = x.shape
        # Flatten the spatial grid into a sequence-first layout.
        memory = rearrange(x, "bs c h w -> (h w) bs c")
        key_pos = pos_embed[:, None].repeat(1, batch_size, 1)
        query_pos = query_embed.transpose(0, 1)
        padding_mask = mask.view(batch_size, -1)

        decoded = self.decoder(
            query=torch.zeros_like(query_pos),
            key=memory,
            value=memory,
            key_pos=key_pos,
            query_pos=query_pos,
            key_padding_mask=padding_mask,
            attn_masks=[attn_masks, None],
            reg_branch=reg_branch,
        )
        return decoded.transpose(1, 2), memory


@TRANSFORMER.register_module()
class AFSAImageTransformer(BaseModule):
    """Transformer wrapper that decodes queries against multi-view image features.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of the encoder layer
            sequence. When None, no encoder is built. Defaults to None.
        decoder ((`mmcv.ConfigDict` | Dict)): Config of the decoder layer
            sequence; it is always built.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Stored as ``self.cross``; not read inside this class.
            Defaults to False.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None, cross=False):
        # Zero-argument super(): the previous call named `CmtImageTransformer`,
        # which is not defined in this module and raised NameError on every
        # instantiation.
        super().__init__(init_cfg=init_cfg)
        if encoder is not None:
            self.encoder = build_transformer_layer_sequence(encoder)
        else:
            self.encoder = None
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.cross = cross

    def init_weights(self):
        """Xavier-uniform init for every sub-module whose weight has >1 dim."""
        # follow the official DETR to init parameters
        for m in self.modules():
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self, x_img, query_embed, rv_pos_embed, attn_masks=None, reg_branch=None, bs=2):
        """Decode zero-initialized queries against flattened image features.

        Args:
            x_img (Tensor): Image features of shape [bs * v, c, h, w].
            query_embed (Tensor): Query positional embedding; its first two
                dims are swapped before decoding (presumably
                [bs, num_query, dim] -> [num_query, bs, dim]).
            rv_pos_embed (Tensor): Key positional embedding of shape
                [bs * v, h, w, c].
            attn_masks (Tensor | None): Passed to the decoder as the first
                entry of ``[attn_masks, None]``. Defaults to None.
            reg_branch: Forwarded unchanged to the decoder.
            bs (int): Batch size used to split the leading ``bs * v`` axis.
                Defaults to 2.

        Returns:
            tuple: ``(out_dec, memory)`` where ``out_dec`` is the decoder
            output with dims 1 and 2 swapped, and ``memory`` is ``x_img``
            rearranged to [v*h*w, bs, c].
        """
        # [bs*v, c, h, w] -> [v*h*w, bs, c]
        memory = rearrange(x_img, "(bs v) c h w -> (v h w) bs c", bs=bs)
        # [bs*v, h, w, c] -> [v*h*w, bs, c]
        pos_embed = rearrange(rv_pos_embed, "(bs v) h w c -> (v h w) bs c", bs=bs)

        query_embed = query_embed.transpose(0, 1)
        # All-zero padding mask of shape [bs, v*h*w] in memory's dtype.
        # NOTE(review): whether a non-bool mask is accepted depends on the
        # configured decoder's attention implementation -- confirm.
        mask = memory.new_zeros(bs, memory.shape[0])

        target = torch.zeros_like(query_embed)
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask,
            attn_masks=[attn_masks, None],
            reg_branch=reg_branch,
            )
        out_dec = out_dec.transpose(1, 2)
        return out_dec, memory
